import astropy.io.fits as fits
import math
import numpy as np
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord
import astropy.units as u
import pickle
from scipy import interpolate
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
c = 3.0e10  # presumably the speed of light in cgs units (cm/s); not referenced by the code visible here

def rms(x, axis=None):
    """Return the root-mean-square of *x*, optionally along *axis*.

    *x* may be any array-like, including plain lists and tuples. ndarray
    subclasses such as masked arrays are passed through unchanged. With
    ``axis=None`` the RMS over all elements is returned. An empty selection
    yields NaN (plus a NumPy RuntimeWarning), exactly as ``np.mean`` does.
    """
    # asanyarray lets sequences work while keeping MaskedArray semantics.
    values = np.asanyarray(x)
    return np.sqrt(np.mean(values ** 2, axis=axis))


#spectral_template_dict=pickle.load(open("spectral_template_dict.pickle", "rb"))

# Source velocity offset (units not established here; not referenced by the
# code visible in this file).
v_src = 0.0

# Parameters of the template transition(s). Field meanings are inferred from
# the key names (line frequency, Einstein A, upper-level energy, upper-level
# degeneracy, transition label) — TODO confirm.
template_trans = {
    'line1': {
        'freq': 203.40752e9,
        'A_ul': 4.812E-06,
        'E_u': 203.7,
        'g_u': 7,
        'trans': '',
    },
}

def _load_cube(imagename, restfreq):
    """Read a FITS cube and return ``(header, data)`` normalised to three axes.

    The primary header and data are both copied, so the file is closed before
    returning. A spectral axis whose CUNIT3 is M/S (any case) is rewritten as
    a frequency axis in Hz referenced to *restfreq*. A fourth axis
    (presumably a degenerate Stokes axis — TODO confirm) is dropped by keeping
    its first plane.
    """
    with fits.open(imagename) as image:
        header = image[0].header.copy()
        data = image[0].data.copy()

    if str(header.get('CUNIT3', '')).strip().upper() == 'M/S':
        # m/s -> km/s -> fraction of c (3.0e5 km/s) -> frequency step at restfreq.
        # NOTE(review): the radio convention nu = restfreq * (1 - v/c) gives a
        # *negative* CDELT3 and honours the reference velocity in the original
        # CRVAL3. The historical mapping is kept here: a positive step with the
        # reference pixel pinned at restfreq (i.e. v = 0 is assumed there), so
        # existing templates keep an increasing frequency axis — confirm.
        header['CDELT3'] = header['CDELT3'] / 1000.0 / 3.0e5 * restfreq
        header['CRVAL3'] = restfreq
        header['CUNIT3'] = 'Hz'
        header['CTYPE3'] = 'FREQ'
        header['SPECSYS'] = 'LSRK'

    if header['NAXIS'] == 4:
        header['NAXIS'] = 3
        for keyword in ('CTYPE4', 'CDELT4', 'CRPIX4', 'CUNIT4', 'CRVAL4'):
            header.pop(keyword, None)  # tolerate headers lacking some of these
        # numpy order is (axis4, axis3, axis2, axis1). Indexing the first plane
        # (unlike np.squeeze) never drops a length-1 spectral axis.
        data = data[0]
    return header, data


def get_spectrum(imagename,coord,radius,restfreq):
    """Extract the aperture-summed spectrum of a FITS cube at a sky position.

    Parameters
    ----------
    imagename : str
        Path of the FITS cube.
    coord : astropy.coordinates.SkyCoord
        Aperture centre. If it is array-valued, its first element is used.
    radius : float
        Aperture radius. It is divided by ``|CDELT2| * 3600``, so it is in
        arcsec if CDELT2 is in degrees (the usual FITS convention — confirm).
    restfreq : float
        Line rest frequency (Hz). Only used when the spectral axis is a
        velocity axis in m/s.

    Returns
    -------
    freq_axis, spectrum, e_spectrum, spectrum_beam, e_spectrum_beam : ndarray
        In order:

        * freq_axis: the spectral-axis coordinate of every channel (Hz after
          a velocity axis has been converted).
        * spectrum: the per-channel sum of pixels strictly inside the aperture.
        * e_spectrum: the per-channel RMS of the pixels outside it.
        * spectrum_beam, e_spectrum_beam: copies of spectrum and e_spectrum.
          The per-beam normalisation, npix_beam = pi*BMAJ*BMIN/(4 ln 2)/CDELT2**2,
          that once distinguished them is disabled.
    """
    header, data = _load_cube(imagename, restfreq)

    # Only the sky position matters, so use the celestial sub-WCS instead of
    # supplying a dummy spectral coordinate.
    px, py = WCS(header).celestial.world_to_pixel(coord)
    x = np.ravel(px)[0]
    y = np.ravel(py)[0]  # previously the x pixel was read twice

    radius_px = radius / (abs(header['CDELT2']) * 3600.0)
    nx = header['NAXIS1']
    ny = header['NAXIS2']
    nchans = header['NAXIS3']

    # The masks are laid out (ny, nx) like each channel plane data[i]
    # (row = y pixel, column = x pixel), so they are used as is. Transposing
    # them would swap the aperture centre's x and y.
    mask, inverse_mask = create_circular_mask2(nx, ny, center=[x, y], radius=radius_px)
    in_rows, in_cols = np.nonzero(mask == 1)
    out_rows, out_cols = np.nonzero(inverse_mask == 1)

    # FITS CRPIX is 1-based: 0-based channel index i is FITS pixel i + 1.
    channel = np.arange(nchans)
    freq_axis = np.asarray(
        header['CRVAL3'] + (channel + 1 - header['CRPIX3']) * header['CDELT3'],
        dtype=float,
    )

    # Sum only the aperture pixels, so NaNs elsewhere in a plane cannot leak in.
    spectrum = data[:, in_rows, in_cols].astype(float).sum(axis=1)
    e_spectrum = np.asarray(rms(data[:, out_rows, out_cols], axis=1), dtype=float)

    return freq_axis, spectrum, e_spectrum, spectrum.copy(), e_spectrum.copy()


def create_circular_mask(h, w, center=None, radius=None):
    """Return an (h, w) boolean array that is True on and inside a circle.

    *center* is an ``(x, y)`` pair in pixel units (x = column, y = row) and
    defaults to ``(int(w/2), int(h/2))``. *radius* defaults to the distance
    from the centre to the nearest image edge. Pixels whose distance equals
    the radius are included.
    """
    if center is None:
        center = (int(w / 2), int(h / 2))
    cx, cy = center[0], center[1]

    if radius is None:
        radius = min(cx, cy, w - cx, h - cy)

    # Broadcast a column vector of row indices against a row vector of
    # column indices instead of using np.ogrid.
    row_idx = np.arange(h)[:, np.newaxis]
    col_idx = np.arange(w)[np.newaxis, :]
    distance = np.sqrt((col_idx - cx) ** 2 + (row_idx - cy) ** 2)
    return distance <= radius

def create_circular_mask2(nx, ny, center=None, radius=None):
    """Return float masks ``(mask, inverse_mask)`` for a disc on an (ny, nx) grid.

    Both arrays have shape ``(ny, nx)``: row index = y pixel, column index =
    x pixel, matching a numpy image plane ``data[y, x]``. ``mask`` is 1.0
    where the pixel lies strictly within *radius* pixels of
    ``center = (x, y)`` and 0.0 elsewhere. ``inverse_mask`` is its complement.

    *center* defaults to ``(int(nx/2), int(ny/2))``. *radius* defaults to the
    distance from the centre to the nearest image edge, matching the defaults
    of create_circular_mask.
    """
    if center is None:
        center = (int(nx / 2), int(ny / 2))
    cx, cy = center[0], center[1]
    if radius is None:
        radius = min(cx, cy, nx - cx, ny - cy)

    # Vectorised: the former per-pixel loop indexed mask[x, y] on a (ny, nx)
    # array. That failed (or skipped pixels) whenever nx != ny.
    cols = np.arange(nx, dtype=float)[np.newaxis, :]
    rows = np.arange(ny, dtype=float)[:, np.newaxis]
    inside = np.sqrt((cols - cx) ** 2 + (rows - cy) ** 2) < radius

    mask = inside.astype(float)
    inverse_mask = (~inside).astype(float)
    return mask, inverse_mask

# Aperture centre for the extraction, held as a length-1 FK5 SkyCoord.
radecstring = '00h00m00.0s 00d00m00.00s'
coord = SkyCoord([radecstring], frame='fk5', unit=(u.hourangle, u.deg))

# Template cubes to process: FITS file name and line rest frequency
# (presumably Hz — it is passed to get_spectrum as restfreq).
spectral_template_dict = {
    'H218O_template': {
        'filename': 'H218O-i38.3-203GHz.fits',
        'freq': 203.40752e9,
    },
}


# Extract each template cube's aperture spectrum (radius 0.4, in arcsec if the
# cube's CDELT2 is in degrees) and store the five returned arrays in its entry.
for key, entry in spectral_template_dict.items():
    print(key)
    (entry['freq_axis'], entry['spectrum'], entry['e_spectrum'],
     entry['spectrum_Jy_beam'], entry['e_spectrum_Jy_beam']) = get_spectrum(
        entry['filename'], coord, 0.4, entry['freq'])



# Plot each extracted aperture spectrum against absolute frequency and save it
# as <key>_spectrum.png.
for key, entry in spectral_template_dict.items():
    freq_ghz = entry['freq_axis'] / 1e9
    fig, ax = plt.subplots(1, 1, figsize=(15, 6))
    ax.plot(freq_ghz, entry['spectrum'], color='black', drawstyle='steps', linewidth=1)
    ax.xaxis.set_major_locator(MultipleLocator(0.005))
    ax.xaxis.set_minor_locator(AutoMinorLocator())
    ax.tick_params(which='both', width=2)
    ax.tick_params(which='major', length=8)
    ax.tick_params(which='minor', length=4)
    ax.set_title(key)
    # NOTE(review): the label assumes the summed cube values are in Jy. For a
    # Jy/beam cube the (disabled) per-beam normalisation would be needed.
    ax.set_ylabel('Flux Density (Jy)')
    ax.set_xlabel('Frequency (GHz)')
    ax.set_xlim(np.min(freq_ghz), np.max(freq_ghz))
    # Save this figure explicitly and close it so figures don't pile up in pyplot.
    fig.savefig(key + '_spectrum.png')
    plt.close(fig)





# Build a peak-normalised, mirror-symmetrised line template from each spectrum,
# then pickle the template of the last key processed.
for key, entry in spectral_template_dict.items():
    # Offset of every channel from the line rest frequency, in freq_axis units
    # (Hz when get_spectrum converted a velocity axis).
    entry['rel_freq_axis'] = entry['freq_axis'] - entry['freq']

    fig, ax = plt.subplots(1, 1, figsize=(15, 6))
    ax.plot(entry['rel_freq_axis'] / 1e9, entry['spectrum'], drawstyle='steps', linewidth=1)
    # Use a separate file name: reusing key + '_spectrum.png' silently
    # overwrote the formatted absolute-frequency plot saved earlier.
    fig.savefig(key + '_rel_spectrum.png')
    plt.close(fig)

    # Normalise to unit peak (yields inf/NaN with a warning if the peak is 0).
    entry['spectrum'] = entry['spectrum'] / np.max(entry['spectrum'])

    # Average with the channel-reversed spectrum to force mirror symmetry about
    # the middle of the channel array, then renormalise to unit peak.
    # NOTE(review): assumes the line sits at the array centre — TODO confirm.
    new_avg_spectrum_wflip = (entry['spectrum'] + np.flip(entry['spectrum'])) / 2.0
    new_avg_spectrum_wflip = new_avg_spectrum_wflip / np.max(new_avg_spectrum_wflip)

    rel_freq_ghz = entry['rel_freq_axis'] / 1e9
    fig, ax = plt.subplots(1, 1, figsize=(15, 6))
    ax.plot(rel_freq_ghz, entry['spectrum'], drawstyle='steps', linewidth=1)
    ax.plot(rel_freq_ghz, np.flip(entry['spectrum']), drawstyle='steps', linewidth=1)
    ax.plot(rel_freq_ghz, new_avg_spectrum_wflip, drawstyle='steps', linewidth=1)
    fig.savefig('flipped_' + key + '_spectra.png')
    plt.close(fig)

# Only the template of the last key processed by the loop above is exported.
spectral_template_lime = {
    'freq_axis': spectral_template_dict[key]['rel_freq_axis'].copy(),
    'spectrum': new_avg_spectrum_wflip.copy(),
}

# Context manager guarantees the pickle is flushed and the handle closed.
with open("spectral_template_lime.pickle", "wb") as pickle_file:
    pickle.dump(spectral_template_lime, pickle_file)









